FSDP tests and checkpointing fixes #26180
Just some quick feedback: when trying to resume from a checkpoint with SHARDED_STATE_DICT (see #26186 for setup/details) with this PR, I get a CUDA OOM error. Full stack trace below.
Full stack trace:
Traceback (most recent call last):
  File "/workspace/axolotl/scripts/finetune.py", line 287, in <module>
    fire.Fire(do_cli)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/workspace/axolotl/scripts/finetune.py", line 283, in do_cli
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
  File "/workspace/axolotl/src/axolotl/train.py", line 116, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1575, in train
    return inner_training_loop(
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1876, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 2768, in training_step
    self.accelerator.backward(loss)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1963, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB (GPU 3; 31.74 GiB total capacity; 30.59 GiB already allocated; 168.38 MiB free; 31.12 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
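The error message hints at fragmentation ("reserved memory is >> allocated memory"), so it can help to check the numbers before reaching for max_split_size_mb. The sketch below parses the message quoted above and computes the gap between reserved and allocated memory. The helper names (parse_oom, _to_gib) are made up for this illustration, and the regular expressions assume the message format shown in this trace; other PyTorch versions word the message differently.

```python
import re

# Message copied from the trace above (first sentence only).
OOM_MESSAGE = (
    "CUDA out of memory. Tried to allocate 172.00 MiB (GPU 3; "
    "31.74 GiB total capacity; 30.59 GiB already allocated; "
    "168.38 MiB free; 31.12 GiB reserved in total by PyTorch)"
)

# Conversion factors to GiB; only the units seen in this message are handled.
_UNIT_TO_GIB = {"MiB": 1.0 / 1024.0, "GiB": 1.0}

# Hypothetical field names mapped to patterns for this message format.
_PATTERNS = {
    "requested": r"Tried to allocate ([\d.]+) (MiB|GiB)",
    "total": r"([\d.]+) (MiB|GiB) total capacity",
    "allocated": r"([\d.]+) (MiB|GiB) already allocated",
    "free": r"([\d.]+) (MiB|GiB) free",
    "reserved": r"([\d.]+) (MiB|GiB) reserved",
}


def _to_gib(value: str, unit: str) -> float:
    """Convert a (number, unit) pair from the message to GiB."""
    return float(value) * _UNIT_TO_GIB[unit]


def parse_oom(message: str) -> dict:
    """Extract the memory figures from an OOM message, all in GiB."""
    stats = {}
    for name, pattern in _PATTERNS.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"could not find '{name}' in message")
        stats[name] = _to_gib(*match.groups())
    return stats


stats = parse_oom(OOM_MESSAGE)
slack = stats["reserved"] - stats["allocated"]
print(f"reserved - allocated: {slack:.2f} GiB")
print(f"requested: {stats['requested']:.2f} GiB, free: {stats['free']:.2f} GiB")
```

Here reserved exceeds allocated by only about half a GiB on a card of roughly 32 GiB, which does not look like the "reserved >> allocated" case the message describes. That reading suggests the GPU is simply full rather than fragmented, though it is an interpretation of one message, not a diagnosis.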
Hello @jphme, I do notice an increase in GPU memory consumption of about 600 MB for the above tests when resuming from a checkpoint saved via
Hi @pacman100, sure, but just to clarify: the training started (and ran until the checkpoint) without problems, and it's also possible to extract the model after the training with So you mean that specifically for restarting from a This is quite dangerous, as everyone tunes their runs so that VRAM is maxed out, and that would mean that many runs can't be restarted from checkpoints...
EDIT: OK, I re-read your post. In my case the checkpoint was indeed created with the main branch and I only tried to restart with this PR; if the PR generally increases VRAM consumption, that would explain it. But then I don't understand what exactly the PyTorch issue is. And is there no way (with offloading) to avoid the increased VRAM consumption, given that everything besides checkpointing (training, model extraction) worked fine for me?
(Sorry if I am a bit slow to understand, I'm still new to FSDP/Torch. Many thanks for your work on this!)
This PR doesn't increase VRAM consumption. Internally, it is calling the Torch utility here: and here: These are probably leading to the increased VRAM consumption. |
Thanks! Left a few nits, mostly on the test 🤗
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
So just for further reference (because other people are starting to have the same issue and commented on my closed issue): Checkpoints are currently of no use with Will try with torch nightly if I have the opportunity (there seems to be a new env that could help), unfortunately very busy currently. |
Looking forward to the fixes this brings!
* add fsdp tests
* Update test_fsdp.py
* Update test_fsdp.py
* fixes
* checks
* Update trainer.py
* fix
* fixes for saving/resuming checkpoints
* fixes
* add tests and delete debug statements
* fixing tests
* Update test_fsdp.py
* fix tests
* fix tests
* minor nits
* fix code style and quality
* refactor and modularize test code
* reduce the time of tests
* reduce the test time
* fix test
* reduce test time
* reduce test time
* fix failing tests
* fix
* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* resolve comments

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Hi, has this fix been merged into the new transformers v4.33.3?
Hey @jmzeng, it is not part of v4.33.3 but will be part of v4.34.0 which will be released early next week. In the meantime, you can install from source:
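The install command itself is missing from this copy of the thread. The transformers installation docs give pip install git+https://github.com/huggingface/transformers for source installs, which is presumably what was meant. After installing, a rough way to check whether an environment should contain the fix is to compare the installed version against v4.34.0, the release named in the reply above. The sketch below does that with the standard library only; parse_release and includes_fsdp_fix are hypothetical helper names, and the check assumes the fix really did ship in v4.34.0.

```python
import re
from importlib import metadata

# Release that the reply above says will contain the fix (assumption: it did).
FIX_VERSION = (4, 34, 0)


def parse_release(version: str) -> tuple:
    """Return the leading numeric part, e.g. '4.34.0.dev0' -> (4, 34, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))


def includes_fsdp_fix(version: str) -> bool:
    """True if the version is at or above the release carrying the fix."""
    return parse_release(version) >= FIX_VERSION


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
    else:
        print(f"transformers {installed}: fix expected = {includes_fsdp_fix(installed)}")
```

One caveat: a source install reports a development version such as 4.34.0.dev0, which this check treats as new enough even if the checkout predates the merge, so it is only a coarse indicator for source installs.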
What does this PR do?
Below we will run the different combinations of FSDP SHARDING_STRATEGY and STATE_DICT_TYPE for the run_glue.py transformers example.
Initial setup:
a. FULL_SHARD + FULL_STATE_DICT
i. Command to run:
ii. Kill the process after epoch 1. Run the above command with --resume_from_checkpoint as below:
iii. Plots of loss and learning rate:
b. SHARD_GRAD_OP + FULL_STATE_DICT
Same as above but with the following cmd arg: --fsdp "shard_grad_op auto_wrap"
Plots:
c. FULL_SHARD + SHARDED_STATE_DICT
i. Here, we will need to use the accelerate launcher, as the option to choose SHARDED_STATE_DICT is currently available via accelerate config. Below is the config file fsdp_config.yaml:
ii. Command to run:
Kill the process after epoch 1. Run the above command with --resume_from_checkpoint as below:
iii. Plots:
d. SHARD_GRAD_OP + SHARDED_STATE_DICT
Just run the accelerate config command and choose the SHARD_GRAD_OP sharding strategy to get an fsdp_config.yaml similar to the above case. The rest is the same.
Plots:
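The four cases above form a 2 x 2 matrix of sharding strategy and state dict type. As a compact summary, the sketch below enumerates that matrix and notes which launch route each case uses, following the description: FULL_STATE_DICT cases go through the --fsdp argument, while SHARDED_STATE_DICT cases go through the accelerate launcher with fsdp_config.yaml. The function names are invented for this illustration. The value "full_shard auto_wrap" for case a is an assumption by analogy with the "shard_grad_op auto_wrap" value quoted for case b, since the original command is not shown here.

```python
from itertools import product

SHARDING_STRATEGIES = ("FULL_SHARD", "SHARD_GRAD_OP")
STATE_DICT_TYPES = ("FULL_STATE_DICT", "SHARDED_STATE_DICT")


def describe(strategy: str, state_dict_type: str) -> str:
    """Summarize one test case and the launch route the description gives for it."""
    if state_dict_type == "SHARDED_STATE_DICT":
        # Per the description, this option is chosen via `accelerate config`.
        route = "accelerate launcher with fsdp_config.yaml"
    else:
        # Case b quotes this flag; case a's value is assumed by analogy.
        route = f'--fsdp "{strategy.lower()} auto_wrap"'
    return f"{strategy} + {state_dict_type}: {route}"


def fsdp_combinations() -> list:
    """Return the four cases in the same order as sections a to d."""
    return [
        describe(strategy, state_dict_type)
        for state_dict_type, strategy in product(STATE_DICT_TYPES, SHARDING_STRATEGIES)
    ]


for label, line in zip("abcd", fsdp_combinations()):
    print(f"{label}. {line}")
```

Iterating with the state dict type as the outer loop keeps the output aligned with the a to d labels used in the description.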